"""Render templates."""

import argparse
from importlib import resources
import itertools
import os
import pathlib
import re
import sys
import typing

from cki_lib import misc
from cki_lib import session
from cki_lib import yaml
import jinja2
from jinja2 import compiler
from jinja2 import utils
import sentry_sdk

from cki_tools.credentials import secrets

# Shared HTTP session used by the `url()` template global (Render.global_url).
SESSION = session.get_session('cki_tools.render')


class _EscapingCodeGenerator(compiler.CodeGenerator):
    """jinja2 code generator that routes autoescaping through a custom filter.

    ``Render`` installs this as the environment's ``code_generator_class`` when
    autoescape is enabled.
    """

    def visit_Template(self, *args, **kwargs):
        """Generate the template module, then rebind its ``escape`` name.

        After the parent has emitted the module code, one extra line is written
        that rebinds ``escape`` to ``environment.filters["escape"]``; ``Render``
        registers its JSON escaper under that filter name when autoescape is on.
        NOTE(review): this relies on jinja2's generated code looking up the
        module-level ``escape`` name at render time for autoescaped output;
        re-check when upgrading jinja2.
        """
        super().visit_Template(*args, **kwargs)
        self.writeline('escape = environment.filters["escape"]')


class Render:
    """Render a template via jinja2.

    The template text is combined with ``args`` and with the contents of the
    configured data files/directories, and rendered in a jinja2 environment
    that adds CKI-specific filters and globals.
    """

    def __init__(
        self,
        template,
        *,
        autoescape,
        args,
        data_paths,
        raw_data_paths,
        include_path,
    ):
        # pylint: disable=too-many-arguments
        """Create an instance.

        Args:
            template: template source text to render.
            autoescape: when true, expression output is escaped as JSON
                (see ``filter_escape_json``) instead of emitted verbatim.
            args: optional mapping of variable names to values exposed to the
                template; it is copied, so the caller's mapping is not modified.
            data_paths: optional mapping of names to YAML/JSON files or
                directories, parsed and exposed to the template under that name.
            raw_data_paths: optional mapping of names to files or directories
                whose unparsed text is exposed to the template under that name.
            include_path: optional list of directories searched by
                ``load_template`` for templates referenced by name.
        """
        self.data_paths = data_paths
        self.raw_data_paths = raw_data_paths
        self.include_path = include_path

        self.env = jinja2.Environment(loader=jinja2.FunctionLoader(self.load_template),
                                      undefined=jinja2.StrictUndefined,
                                      keep_trailing_newline=True,
                                      autoescape=autoescape,
                                      extensions=['jinja2.ext.do'])

        self.env.filters['env'] = Render.filter_env
        self.env.filters['is_true'] = misc.strtobool
        self.env.filters['from_yaml'] = lambda d: yaml.load(contents=d)
        self.env.filters['to_kebab_case'] = misc.to_kebab_case
        self.env.filters['to_pascal_case'] = misc.to_pascal_case
        self.env.filters['to_snake_case'] = misc.to_snake_case
        self.env.filters['to_screaming_snake_case'] = misc.to_screaming_snake_case

        self.env.globals['cki_variable'] = secrets.variable
        self.env.globals['cki_secret'] = secrets.secret
        self.env.globals['url'] = Render.global_url
        self.env.globals['env'] = os.environ
        self.env.globals['is_production'] = misc.is_production()
        self.env.globals['is_production_or_staging'] = misc.is_production_or_staging()
        self.env.globals['is_staging'] = misc.is_staging()
        self.env.globals['deployment_environment'] = misc.deployment_environment()
        self.env.globals['bucket_spec'] = Render.global_bucket_spec

        if autoescape:
            # Point every escaping entry point at the JSON escaper.
            self.env.filters['e'] = Render.filter_escape_json
            self.env.filters['escape'] = Render.filter_escape_json
            self.env.filters['forceescape'] = lambda v: Render.filter_escape_json(str(v))
            # NOTE(review): returns a plain str (not Markup), so autoescaping
            # JSON-escapes the result once more on output -- presumably meant
            # to embed JSON as a string value; confirm.
            self.env.filters['tojson'] = lambda v: str(utils.htmlsafe_json_dumps(v))
            # Rebinds the generated module's `escape` to the 'escape' filter
            # registered above.
            self.env.code_generator_class = _EscapingCodeGenerator
            # Identity finalize that takes the context. NOTE(review): presumably
            # this keeps jinja2 from constant-folding output at compile time
            # with its built-in HTML escape, so all output goes through the
            # runtime JSON escaper; re-check on jinja2 upgrades.
            self.env.finalize = utils.pass_context(lambda _, v: v)

        # In-memory template sources served by load_template; internal names
        # are prefixed with '__'.
        self.data_file_contents = {
            '__template': template,
        }
        self.data = dict(args or {})

    @staticmethod
    def filter_escape_json(value):
        """Escape any unsafe values as JSON.

        Values providing ``__html__`` (e.g. ``Markup``) are considered safe and
        returned unchanged; everything else is serialized with jinja2's
        ``htmlsafe_json_dumps``.
        """
        if hasattr(value, "__html__"):
            return value
        return utils.htmlsafe_json_dumps(value)

    @staticmethod
    def global_bucket_spec(url_variable_name: str, secret_name: typing.Optional[str] = None) -> str:
        """Dynamically create a bucket spec.

        This method provides a migration path from pre-assembled bucket specs
        to explicit credentials and should not be used in new code.

        Args:
            url_variable_name: name of a CKI variable holding a URL of the form
                ``scheme://host/bucket/prefix``.
            secret_name: optional secret whose deployed value is the secret
                access key and whose metadata provides ``access_key_id`` (and
                optionally ``endpoint_url``).

        Returns:
            ``endpoint_url|access_key_id|secret_access_key|bucket|prefix``;
            the credential fields are empty without ``secret_name``.

        Raises:
            ValueError: if the URL does not split into the five parts above.
            Exception: if the secret's ``endpoint_url`` differs from the URL's.
        """
        scheme, _, host, bucket, prefix = secrets.variable(url_variable_name).split('/', 4)
        endpoint_url = f'{scheme}//{host}'
        if secret_name:
            secret_data = secrets.secret(f'{secret_name}[deployed]:')[0]
            secret_meta = secrets.secret(f'{secret_name}[deployed]#')[0]
            if (secret_endpoint_url := secret_meta.get('endpoint_url', '')):
                if secret_endpoint_url != endpoint_url:
                    # pylint: disable=broad-exception-raised
                    raise Exception(f'Endpoint mismatch: {secret_endpoint_url} != {endpoint_url}')
            access_key_id = secret_meta['access_key_id']
            secret_access_key = secret_data['value']
        else:
            access_key_id = ''
            secret_access_key = ''
        return f'{endpoint_url}|{access_key_id}|{secret_access_key}|{bucket}|{prefix}'

    @staticmethod
    def global_url(url, binary=False, json=False):
        """Try to get a file via requests.

        Returns the raw bytes if ``binary``, the decoded JSON if ``json``,
        otherwise the response text. Error status codes raise via
        ``raise_for_status``.
        """
        response = SESSION.get(url)
        response.raise_for_status()
        return response.content if binary else response.json() if json else response.text

    @staticmethod
    def filter_env(value):
        """Try to get a variable from env if the value matches ${*}.

        Both ``$NAME`` and ``${NAME}`` references inside a string are replaced
        by the environment value; non-string values are returned unchanged.
        A reference to an unset variable raises ``KeyError``.
        """
        if not isinstance(value, str):
            return value

        # Try to match ${*} with or without brackets.
        return re.sub(
            r'\$(\w+|\{([^}]*)\})',
            lambda m: os.environ[m.group(2) or m.group(1)],
            value
        )

    def load_template(self, name):
        """Feed a template into jinja2.

        In-memory sources in ``data_file_contents`` take precedence; otherwise
        the first ``include_path`` directory containing a file ``name`` wins.

        Returns:
            the template source as text, or None if it cannot be found.
        """
        if name in self.data_file_contents:
            return self.data_file_contents[name]
        for path in (self.include_path or []):
            template_path = pathlib.Path(path) / name
            # is_file(): a same-named directory must not shadow (and crash on)
            # a real template further down the include path.
            if template_path.is_file():
                # Decode explicitly, like the main template, instead of relying
                # on the locale encoding.
                return template_path.read_text(encoding='utf8')
        return None

    @staticmethod
    def _clean_path(path, parse):
        """Turn a file or directory name into a data key.

        With ``parse``, everything from the first dot on is dropped
        (``foo.yml.j2`` -> ``foo``); otherwise only a trailing ``.j2`` is
        removed (``foo.txt.j2`` -> ``foo.txt``).
        """
        return re.sub(r'\..*', '', path) if parse else re.sub(r'\.j2$', '', path)

    def _read_file(self, file_path, parse):
        """Read one data file, rendering ``.j2`` files as templates first.

        A ``.j2`` file is registered under a ``__``-prefixed name and rendered
        with the data collected so far. The result is parsed as YAML when
        ``parse`` is true, otherwise returned as text.
        """
        contents = file_path.read_text(encoding='utf8')
        if file_path.suffix == '.j2':
            data_contents_key = '__' + file_path.as_posix()
            self.data_file_contents[data_contents_key] = contents
            contents = self.env.get_template(data_contents_key).render(self.data)
        return yaml.load(contents=contents) if parse else contents

    def _read_dir(self, dir_path, parse):
        """Read all data files below ``dir_path`` into a nested dict.

        Subdirectories become nested dicts and names are turned into keys via
        ``_clean_path``. With ``parse`` only YAML/JSON files (optionally with a
        ``.j2`` suffix) are read; otherwise every file is read as text.
        """
        data = {}
        extensions = ['.yml', '.yml.j2', '.yaml', '.yaml.j2',
                      '.json', '.json.j2'] if parse else ['']
        # Sort each glob so the result (and its key order) does not depend on
        # the filesystem's listing order. The extension order is kept, so for
        # clashing keys (e.g. foo.yml and foo.json) the extension listed later
        # still wins.
        file_paths = itertools.chain.from_iterable(
            sorted(dir_path.glob(f'**/*{extension}')) for extension in extensions)
        for file_path in file_paths:
            if not file_path.is_file():
                continue
            parts = file_path.relative_to(dir_path).parts
            sub_data = data
            for part in parts[:-1]:
                sub_data = sub_data.setdefault(self._clean_path(part, parse), {})
            sub_data[self._clean_path(parts[-1], parse)] = self._read_file(file_path, parse)
        return data

    def _read(self, path_string, parse):
        """Read a single data file, or a whole directory via ``_read_dir``.

        NOTE(review): a path that does not exist is treated as a directory and
        silently yields an empty dict rather than an error.
        """
        path = pathlib.Path(path_string)
        return self._read_file(path, parse) if path.is_file() else self._read_dir(path, parse)

    def render(self):
        """Render the template.

        Loads all data paths, then all raw data paths, into ``self.data`` (so
        ``.j2`` raw data files can already use the parsed data) and renders the
        main template with it.
        """
        for name, path in (self.data_paths or {}).items():
            self.data[name] = self._read(path, parse=True)
        for name, path in (self.raw_data_paths or {}).items():
            self.data[name] = self._read(path, parse=False)

        return self.env.get_template('__template').render(self.data)


def main(args):
    """Run the main CLI interface.

    Reads the template (from a file, a Python package resource, or stdin when
    no template is given), renders it with the data and variables given on
    the command line and writes the result to ``--output``.

    Args:
        args: command line arguments, without the program name.
    """
    parser = argparse.ArgumentParser(
        description='Render templates',
    )
    parser.add_argument('template', nargs='?',
                        help='template file to render')
    parser.add_argument('--package',
                        help='Get the template from a package instead of from the file system')
    parser.add_argument('--data', action=misc.StoreNameValuePair,
                        default={}, metavar='NAME=PATH',
                        help='JSON/YAML data file(s) to expose as NAME')
    parser.add_argument('--raw-data', action=misc.StoreNameValuePair,
                        default={}, metavar='NAME=PATH',
                        help='raw data file(s) to expose as NAME')
    parser.add_argument('--arg', action=misc.StoreNameValuePair,
                        default={}, metavar='NAME=VALUE',
                        help='variable value to expose as NAME')
    parser.add_argument('--include-path', action='append', default=[],
                        help='path for template includes')
    # A plain path instead of argparse.FileType: the file must only be opened
    # once rendering succeeded, and stdout must never be closed. '-' keeps
    # meaning stdout, as with FileType.
    parser.add_argument('--output', default='-',
                        help='Output path for the rendered file.')
    parser.add_argument('--autoescape', action=argparse.BooleanOptionalAction, default=True,
                        help='Automatically escape as JSON')
    parsed_args = parser.parse_args(args)

    if not parsed_args.template:
        template = sys.stdin.read()
    elif parsed_args.package:
        template = resources.files(parsed_args.package).joinpath(
            parsed_args.template).read_text(encoding='utf8')
    else:
        template = pathlib.Path(parsed_args.template).read_text(encoding='utf8')

    # Render completely before touching the output, so a rendering error does
    # not leave a truncated output file behind.
    result = Render(template,
                    autoescape=parsed_args.autoescape,
                    args=parsed_args.arg,
                    data_paths=parsed_args.data,
                    raw_data_paths=parsed_args.raw_data,
                    include_path=parsed_args.include_path).render()

    if parsed_args.output == '-':
        # Use the current sys.stdout and leave it open for the caller.
        sys.stdout.write(result)
    else:
        # Same encoding as used for reading templates.
        pathlib.Path(parsed_args.output).write_text(result, encoding='utf8')


if __name__ == '__main__':
    # Initialize the Sentry SDK through cki_lib's helper, then run the CLI
    # with the process arguments (program name stripped).
    misc.sentry_init(sentry_sdk)
    main(sys.argv[1:])
